#!/usr/bin/python

################################################################
### Apply language models.
################################################################

import sys,os,re,glob,math,glob,signal,traceback,sqlite3
import matplotlib
if "DISPLAY" not in os.environ: matplotlib.use("AGG")
else: matplotlib.use("GTK")
from scipy.ndimage import interpolation,measurements
from pylab import *
from optparse import OptionParser
from multiprocessing import Pool
import ocrolib
from ocrolib import number_of_processors,die,docproc
from scipy import stats
from ocrolib import dbhelper,ocroio
from ocrolib import ligatures,rect_union,Record
import multiprocessing
import ocrofst
from ocrofst.lattice import Lattice

# Turn Ctrl-C (SIGINT) into sys.exit(1), i.e. a SystemExit rather than KeyboardInterrupt.
signal.signal(signal.SIGINT,lambda *args:sys.exit(1))

import argparse
# NOTE: argparse expands "%(prog)s" (the optparse-style "%prog" used previously was
# printed literally in the help output).
parser = argparse.ArgumentParser(
    description="Apply a language model to recognizer output (lattices, or FSTs with -F) "
                "and write the aligned text (.txt), character segmentation (.cseg) and "
                "per-character costs (.costs) for each input.",
    usage="%(prog)s [-l langmod] [options] files...")

# search / model
parser.add_argument("-b","--beam",type=int,default=10000,
                    help="beam size passed to ocrofst.beam_search (default: %(default)s)")
parser.add_argument("-F","--usefst",action="store_true",help="use fst instead of lattice file")
parser.add_argument("-l","--langmod",help="language model",default="default.fst")

# output
parser.add_argument("-s","--suffix",help="output suffix for writing result",default="")
parser.add_argument("-O","--overwrite",help="overwrite outputs",action="store_true")

# execution
parser.add_argument("-Q","--parallel",type=int,default=multiprocessing.cpu_count(),
                    help="number of parallel processes to use (values < 2 run sequentially)")

# debugging / display
parser.add_argument("-D","--Display",help="display",action="store_true")
parser.add_argument("--debug_line",action="store_true",
                    help="print a summary line for each beam search result")
parser.add_argument("--debug_rawalign",action="store_true",
                    help="print the raw beam-search alignment arc by arc")
parser.add_argument("--debug_aligned",action="store_true")
parser.add_argument("--debug_select",action="store_true")
# type=int so a value given on the command line is not left as a string
parser.add_argument('--dgrid',type=int,default=8,help="grid size for display")
parser.add_argument("files",default=[],nargs='*',
                    help="input files, glob patterns, or directories "
                         "(directories are searched for ????/??????.lattice, or .fst with -F)")
args = parser.parse_args()

def concat(l):
    """Flatten one level of nesting.

    Returns a new list holding the items of every iterable in `l`, in order;
    e.g. concat([[1,2],[3]]) == [1,2,3] and concat([]) == [].  The inputs
    are not modified.
    """
    return [item for chunk in l for item in chunk]

args.files = concat([glob.glob(file) for file in args.files])
if args.usefst:
    args.files = concat([glob.glob(file+"/????/??????.fst")
                              if os.path.isdir(file) else [file] for file in args.files])
else:
    args.files = concat([glob.glob(file+"/????/??????.lattice")
                              if os.path.isdir(file) else [file] for file in args.files])


debug = 0

lmodel = ocrolib.ocropus_find_file(args.langmod)
langmod = ocrofst.OcroFST()
print "# loading",lmodel
langmod.load(lmodel)


def align1(job):
    fname = job
    loaded = None
    if args.usefst:
        fst_file = ocrolib.fvariant(fname,"fst")
        assert os.path.exists(fst_file),fst_file+": does not exist"
        fst = ocrofst.OcroFST()
        fst.load(fst_file)
        fname = fst_file
    else:
        lattice_file = ocrolib.fvariant(fname,"lattice")
        assert os.path.exists(lattice_file),lattice_file+": does not exist"
        gr = Lattice()
        with open(lattice_file) as stream:
            gr.loadLattice(stream)
        fst = gr.getLatticeAsFST()
        fname = lattice_file

    result = ocrofst.beam_search(fst,langmod,args.beam)
    if args.debug_line:
        print "line",i,len(result[0]),sum(result[4])
    if len(result[0])<=1: 
        print fname,": FAILED"
        return

    v1,v2,ins,outs,costs = result

    if args.debug_rawalign:
        for i in range(len(v1)):
            print "raw-align %3d [%3d %3d] (%3d %3d) %6.2f"%\
                (i,v1[i],v2[i],ins[i]>>16,ins[i]&0xffff,costs[i]),unichr(outs[i])

    sresult = []
    scosts = []
    segs = []

    n = len(ins)
    i = 1
    while i<n:
        if outs[i]==ord(" "):
            sresult.append(" ")
            scosts.append(costs[i])
            segs.append((0,0))
            i += 1
            continue
        # pick up ligatures indicated by the recognizer (multiple sequential 
        # output characters with the same input segments)
        j = i+1
        while j<n and ins[j]==ins[i]:
            j += 1
        cls = "".join([unichr(c) for c in outs[i:j]])
        sresult.append(cls)
        scosts.append(sum(costs[i:j]))
        start = (ins[i]>>16)
        end = (ins[i]&0xffff)
        segs.append((start,end))
        i = j

    # read the raw segmentation
    rseg_file = ocrolib.fvariant(fname,"rseg")
    if not os.path.exists(rseg_file):
        print "*"+rseg_file,": NOT FOUND"
        return
    rseg = ocroio.read_line_segmentation(rseg_file)
    rseg_boxes = docproc.seg_boxes(rseg)

    # Now run through the segments and create a table that maps rseg
    # labels to the corresponding output element.

    assert len(sresult)==len(segs)
    assert len(scosts)==len(segs)

    aligned = "".join(sresult)
    aligned = re.sub(r'\0','',aligned)
    print fname,":",aligned

    bboxes = []

    rmap = zeros(amax(rseg)+1,'i')
    for i in range(1,len(segs)):
        start,end = segs[i]
        if start==0 or end==0: continue
        rmap[start:end+1] = i
        bboxes.append(rect_union(rseg_boxes[start:end+1]))
    assert rmap[0]==0

    cseg = zeros(rseg.shape,'i')
    for i in range(cseg.shape[0]):
        for j in range(cseg.shape[1]):
            cseg[i,j] = rmap[rseg[i,j]]

    assert len(segs)==len(sresult) 
    assert len(segs)==len(scosts)

    assert amin(cseg)==0,"amin(cseg)!=0 (%d,%d)"%(amin(cseg),amax(cseg))

    # first work on the list output; here, one list element should
    # correspond to each cost
    result = sresult
    assert len(result)==len(scosts),\
        "output length %d differs from cost length %d"%(len(result),len(costs))
    assert amax(cseg)<len(result),\
        "amax(cseg) %d not consistent with output length %d"%(amax(cseg),len(output_l))

    # if there are spaces at the end, trim them (since they will never have a corresponding cseg)

    while len(result)>0 and result[-1]==" ":
        result = result[:-1]
        costs = costs[:-1]

    cseg_file = ocrolib.fvariant(fname,"cseg",args.suffix)
    if not args.overwrite:
        if os.path.exists(cseg_file): die("%s: already exists",cseg_file)
    ocrolib.write_line_segmentation(cseg_file,cseg)
    # aligned text line
    ocrolib.write_text(ocrolib.fvariant(fname,"txt",args.suffix),aligned)
    # per character costs
    with ocrolib.fopen(fname,"costs",args.suffix,mode="w") as stream:
        for i in range(len(costs)):
            stream.write("%d %g\n"%(i,costs[i]))
    # true ground truth (best-matching line in the original transcription)
    # ocrolib.write_text(ocrolib.fvariant(fname,"tru",args.suffix),bestline)

def safe_align1(job):
    """Run align1() on one job without letting any failure escape.

    This is the worker function handed to Pool.imap: whatever align1 raises is
    printed as a traceback and swallowed, so a single bad input does not stop
    the remaining jobs.  Always returns None.
    """
    try:
        align1(job)
    except:
        # Deliberately a bare except: it also catches SystemExit, which the
        # SIGINT handler raises via sys.exit(1).
        traceback.print_exc()
    return None

jobs = args.files

print "got",len(jobs),"jobs"

if args.parallel<2:
    for arg in jobs: align1(arg)
else:
    pool = Pool(processes=args.parallel)
    result = []
    for r in pool.imap(safe_align1,jobs):
        result.append(r)
        if len(result)%100==0: print "==========",len(result),len(jobs),"=========="
